package querymgr

import (
	"context"

	"github.com/pkg/errors"
	deploymentDataStore "github.com/stackrox/rox/central/deployment/datastore"
	imgDataStore "github.com/stackrox/rox/central/image/datastore"
	imgV2DataStore "github.com/stackrox/rox/central/imagev2/datastore"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/cache"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/common"
	vulnReqDataStore "github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/datastore"
	"github.com/stackrox/rox/central/vulnmgmt/vulnerabilityrequest/utils"
	v1 "github.com/stackrox/rox/generated/api/v1"
	"github.com/stackrox/rox/generated/storage"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/sac/resources"
)

var (
	// allAccessCtx is a background context granted unrestricted SAC access.
	allAccessCtx = sac.WithAllAccess(context.Background())

	// requesterOrApproverSAC is a SAC helper scoped to both the
	// VulnerabilityManagementRequests and VulnerabilityManagementApprovals
	// resources; VulnsWithState uses it to require read access to either one.
	requesterOrApproverSAC = sac.ForResources(
		sac.ForResource(resources.VulnerabilityManagementRequests),
		sac.ForResource(resources.VulnerabilityManagementApprovals),
	)
)

// queryManagerImpl answers questions about vulnerability requests. It does so
// by combining the vulnerability request datastore and two request caches with
// the deployment and image datastores that a request's scope is resolved against.
type queryManagerImpl struct {
	deployments deploymentDataStore.DataStore
	images      imgDataStore.DataStore
	imageV2s    imgV2DataStore.DataStore
	vulnReqs    vulnReqDataStore.DataStore

	// activeReqCache is consulted before pendingReqCache when resolving the
	// effective request for an image/CVE (see EffectiveVulnReq); only the
	// active cache backs VulnsWithState.
	activeReqCache  cache.VulnReqCache
	pendingReqCache cache.VulnReqCache
}

// DeploymentCount counts the deployments that match query after it has been
// narrowed to the images covered by the vulnerability request requestID.
// Errors from loading the request or from the deployment datastore are
// returned unchanged.
func (m *queryManagerImpl) DeploymentCount(ctx context.Context, requestID string, query *v1.Query) (int, error) {
	affectedQuery, err := m.getAffectedImagesQueryForVulnReq(ctx, requestID, query)
	if err != nil {
		return 0, err
	}
	return m.deployments.Count(ctx, affectedQuery)
}

// ImageCount counts the images (from the images datastore) that match query
// after it has been narrowed to the images covered by the vulnerability
// request requestID.
func (m *queryManagerImpl) ImageCount(ctx context.Context, requestID string, query *v1.Query) (int, error) {
	affectedQuery, err := m.getAffectedImagesQueryForVulnReq(ctx, requestID, query)
	if err != nil {
		return 0, err
	}
	return m.images.Count(ctx, affectedQuery)
}

// Deployments returns the raw deployments matching query once it has been
// narrowed to the images covered by the vulnerability request requestID.
func (m *queryManagerImpl) Deployments(ctx context.Context, requestID string, query *v1.Query) ([]*storage.Deployment, error) {
	affectedQuery, err := m.getAffectedImagesQueryForVulnReq(ctx, requestID, query)
	if err != nil {
		return nil, err
	}
	return m.deployments.SearchRawDeployments(ctx, affectedQuery)
}

// Images returns the raw images (from the images datastore) matching query
// once it has been narrowed to the images covered by the vulnerability request
// requestID.
func (m *queryManagerImpl) Images(ctx context.Context, requestID string, query *v1.Query) ([]*storage.Image, error) {
	affectedQuery, err := m.getAffectedImagesQueryForVulnReq(ctx, requestID, query)
	if err != nil {
		return nil, err
	}
	return m.images.SearchRawImages(ctx, affectedQuery)
}

// ImageV2s is the ImageV2 counterpart of Images: it returns the raw
// storage.ImageV2 records from the imageV2s datastore that match query once
// it has been narrowed to the images covered by the vulnerability request
// requestID.
func (m *queryManagerImpl) ImageV2s(ctx context.Context, requestID string, query *v1.Query) ([]*storage.ImageV2, error) {
	affectedQuery, err := m.getAffectedImagesQueryForVulnReq(ctx, requestID, query)
	if err != nil {
		return nil, err
	}
	return m.imageV2s.SearchRawImages(ctx, affectedQuery)
}

// VulnsWithState returns the vulnerability-state map held by the active request
// cache for the image identified by scope (registry, remote, tag).
//
// The cache is read directly rather than through a datastore, so access is
// checked here. The caller needs read access to at least one of the
// VulnerabilityManagementRequests / VulnerabilityManagementApprovals resources;
// otherwise sac.ErrResourceAccessDenied is returned.
func (m *queryManagerImpl) VulnsWithState(ctx context.Context, scope common.VulnReqScope) (map[string]storage.VulnerabilityState, error) {
	allowed, err := requesterOrApproverSAC.ReadAllowedToAny(ctx)
	switch {
	case err != nil:
		return nil, err
	case !allowed:
		return nil, sac.ErrResourceAccessDenied
	}
	return m.activeReqCache.GetVulnsWithState(scope.Registry, scope.Remote, scope.Tag), nil
}

// EffectiveVulnReq returns the vulnerability request currently in effect for
// cve on the image identified by scope.
//
// A request found in the active cache takes precedence. The pending cache is
// consulted only when the active cache has no entry. The request is then
// loaded from the vulnerability request datastore.
// It returns (nil, nil) when neither cache has an entry or the stored request
// no longer exists. A datastore error is returned as-is.
func (m *queryManagerImpl) EffectiveVulnReq(ctx context.Context, cve string, scope common.VulnReqScope) (*storage.VulnerabilityRequest, error) {
	reqID := m.activeReqCache.GetEffectiveVulnReqIDForImage(scope.Registry, scope.Remote, scope.Tag, cve)
	if reqID == "" {
		// No active request covers this CVE/image; fall back to a pending one.
		reqID = m.pendingReqCache.GetEffectiveVulnReqIDForImage(scope.Registry, scope.Remote, scope.Tag, cve)
		if reqID == "" {
			return nil, nil
		}
	}

	req, found, err := m.vulnReqs.Get(ctx, reqID)
	if err != nil {
		return nil, err
	}
	if !found {
		// The cache referenced an ID the datastore no longer has.
		return nil, nil
	}
	return req, nil
}

// getAffectedImagesQueryForVulnReq loads the vulnerability request identified
// by requestID and combines it with query via utils.GetAffectedImagesQuery.
// Presumably the result restricts query to the images the request applies to;
// confirm against utils.GetAffectedImagesQuery.
//
// It returns an error if the request cannot be read from the datastore, if it
// does not exist, or if utils.GetAffectedImagesQuery fails.
func (m *queryManagerImpl) getAffectedImagesQueryForVulnReq(ctx context.Context, requestID string, query *v1.Query) (*v1.Query, error) {
	request, found, err := m.vulnReqs.Get(ctx, requestID)
	if err != nil {
		// Always propagate the failure. Returning (nil, nil) here would make
		// callers proceed with a nil query and no indication anything went wrong.
		return nil, errors.Wrapf(err, "retrieving vulnerability request %q", requestID)
	}
	if !found {
		return nil, errors.Errorf("vulnerability request %q not found", requestID)
	}
	return utils.GetAffectedImagesQuery(request, query)
}
